llama : remove all_pos_0, all_pos_1, all_seq_id from llama_batch #9745
| I don't see a clear motivation for removing this. I believe that single sequence usage is by far the most common way llama.cpp is used, and removing this function will require most applications to add a lot of boilerplate. We should aim to make the llama.cpp API as simple as possible to use. | 
| My main motivation for this PR is that, instead of having an API call solely for backward compatibility, we could keep it as a utility rather than a core API. The second motivation is that keeping these backward-compat struct members makes the code inside   | 
| 
 I think in this use case, simple specify  So if we really want to simplify the usage for end user, we could allow user to only set  Even more simple,  | 
| There is a lot we could do to simplify the  
 | 
| 
 Let me clarify a bit more, what I mean was that in all examples, we always set: 
 So I assume that in 99% of cases, if the user wants to work with a single sequence (the most basic usage), then  
 The problem with such change is that even without touching It seems OK for me to keep  In any cases, I still strongly prefer to remove  | 
| Sounds good to me. Other than causing an ABI break, removing  | 
Force-pushed: 697a3f9 to 1c48616
// - pos    : the positions of the respective token in the sequence
//            (if set to NULL, the token position will be tracked automatically by llama_decode)
// - seq_id : the sequence to which the respective token belongs
//            (if set to NULL, the sequence ID will be assumed to be 0)
// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
//            (if set to NULL, only the logits for last token will be returned)
//
@slaren @ggerganov I updated the behavior of llama_batch to adapt to the removal of all_pos_0, all_pos_1, all_seq_id, please let me know what you think about this implementation. Thank you!
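The NULL defaults described in the header comment above can be modeled with a small standalone sketch. It does not call llama.cpp: `SketchBatch` and `fill_defaults` are hypothetical names, and an empty vector stands in for a NULL pointer. It only illustrates the documented behavior, not how llama_decode implements it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a batch: an empty vector plays the role of NULL.
struct SketchBatch {
    int32_t n_tokens = 0;
    std::vector<int32_t> pos;     // empty => positions are tracked automatically
    std::vector<int32_t> seq_id;  // empty => every token belongs to sequence 0
    std::vector<int8_t>  logits;  // empty => output only for the last token
};

// Fill in whatever the caller left unset.
// `next_pos` is the first free position of the sequence (0 for an empty context).
inline SketchBatch fill_defaults(SketchBatch b, int32_t next_pos) {
    assert(b.n_tokens >= 0);
    const std::size_t n = static_cast<std::size_t>(b.n_tokens);
    if (b.pos.empty()) {
        b.pos.resize(n);
        for (std::size_t i = 0; i < n; ++i) {
            b.pos[i] = next_pos + static_cast<int32_t>(i);
        }
    }
    if (b.seq_id.empty()) {
        b.seq_id.assign(n, 0);        // default sequence ID is 0
    }
    if (b.logits.empty()) {
        b.logits.assign(n, 0);        // no output for most tokens...
        if (n > 0) {
            b.logits.back() = 1;      // ...except the last one
        }
    }
    return b;
}
```

A caller that sets only `n_tokens` gets consecutive positions starting at `next_pos`, sequence 0 for every token, and an output flag on the final token only.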
  result2 += next_token_str;

- if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+ if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
This will generate a batch for seq_id == 0 and it needs to be seq_id == 1
make -j && ./llama-save-load-state -m ${some_model}
Thanks for spotting that! Fixed in 6395174
        
          
                examples/perplexity/perplexity.cpp
              
                Outdated
          
        
      | const int batch_start = start + j * n_batch; | ||
| const int batch_size = std::min(end - batch_start, n_batch); | ||
|  | ||
| llama_batch batch = llama_batch_init(batch_size, 0, 1); | 
Move the llama_batch outside the loop and reuse it. Maybe use the common_batch_ API to make it a little less cumbersome.
        
          
                src/llama.cpp
              
                Outdated
          
        
      | batch.n_seq_id = n_seq_id.data(); | ||
| } | ||
| if (!batch.seq_id) { | ||
| seq_id.resize(batch.n_tokens); | 
Make this also NULL terminated for consistency (see llama_batch_init):
- seq_id.resize(batch.n_tokens);
+ seq_id.resize(batch.n_tokens + 1);
Fixed in 7264596
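To illustrate the NULL-terminated list suggested above, here is a minimal standalone sketch. `SeqIdTable`, `build_default_seq_ids` and `count_until_null` are hypothetical names for illustration; the sketch does not reproduce the actual code in src/llama.cpp. It allocates one extra slot so a consumer can walk the pointer list without knowing the token count.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical holder for the default per-token sequence-ID pointers.
// Use it in place (do not copy it): the pointers refer to `seq_id_0` inside it.
struct SeqIdTable {
    int32_t seq_id_0 = 0;            // the single default sequence ID
    std::vector<int32_t *> seq_id;   // one pointer per token, plus a NULL terminator
};

// Point every token at the default sequence ID and terminate the list with NULL.
inline void build_default_seq_ids(SeqIdTable & t, int32_t n_tokens) {
    const std::size_t n = static_cast<std::size_t>(n_tokens);
    t.seq_id.assign(n + 1, nullptr);     // n_tokens + 1 slots, last one stays NULL
    for (std::size_t i = 0; i < n; ++i) {
        t.seq_id[i] = &t.seq_id_0;
    }
}

// Walk the list until the NULL terminator, without being told its length.
inline int32_t count_until_null(int32_t * const * list) {
    int32_t n = 0;
    while (list[n] != nullptr) {
        ++n;
    }
    return n;
}
```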
        
          
                examples/infill/infill.cpp
              
                Outdated
          
        
      |  | ||
| llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); | ||
| llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); | ||
| llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard); | 
Small explanation of what's happening: we are supposed to shift all tokens starting from n_keep + n_discard + 1, so the end of the range must be n_past + 1 (or we can simply set it to -1, which means [p0, inf))
Hm, I don't think n_past + 1 is needed here. There shouldn't be a token with pos == n_past in the KV cache.
But yes, using either n_past or -1 would achieve the same thing. I think using n_past is more illustrative.
ok thanks, I figured out that I counted the token from 1, not from 0. I fixed that in 5d99ae4
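The range discussion above can be checked with a small standalone model. It is a sketch under stated assumptions: the cache is reduced to a plain list of positions, `seq_rm`, `seq_add` and `context_shift` are hypothetical helpers that only mimic the half-open [p0, p1) convention described in the thread (with a negative p1 meaning "to infinity"), and nothing here calls llama.cpp.

```cpp
#include <algorithm>
#include <vector>

// True if position p lies in [p0, p1); a negative p1 means [p0, inf).
inline bool in_range(int p, int p0, int p1) {
    return p >= p0 && (p1 < 0 || p < p1);
}

// Drop every cached position in [p0, p1).
inline void seq_rm(std::vector<int> & cells, int p0, int p1) {
    cells.erase(std::remove_if(cells.begin(), cells.end(),
                               [=](int p) { return in_range(p, p0, p1); }),
                cells.end());
}

// Add delta to every cached position in [p0, p1).
inline void seq_add(std::vector<int> & cells, int p0, int p1, int delta) {
    for (int & p : cells) {
        if (in_range(p, p0, p1)) {
            p += delta;
        }
    }
}

// Context shift on a cache holding positions 0 .. n_past-1:
// keep the first n_keep, discard the next n_discard, shift the rest down.
// `p1` is the end of the shifted range, so different choices can be compared.
inline std::vector<int> context_shift(int n_past, int n_keep, int n_discard, int p1) {
    std::vector<int> cells;
    for (int i = 0; i < n_past; ++i) {
        cells.push_back(i);
    }
    seq_rm (cells, n_keep, n_keep + n_discard);
    seq_add(cells, n_keep + n_discard, p1, -n_discard);
    return cells;
}
```

Because no cached token has position n_past, passing n_past, n_past + 1 or -1 as the end of the range gives the same result in this model, which matches the conclusion reached in the review.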
…l-org#9745)
* refactor llama_batch_get_one
* adapt all examples
* fix simple.cpp
* fix llama_bench
* fix
* fix context shifting
* free batch before return
* use common_batch_add, reuse llama_batch in loop
* null terminated seq_id list
* fix save-load-state example
* fix perplexity
* correct token pos in llama_batch_allocr
Motivation
While working on the ability to add both embeddings and tokens to the same batch, I noticed that the old API for llama_batch, namely all_pos_0, all_pos_1 and all_seq_id, has been there for quite a long time.

Migration guide
The recommended way is to use llama_batch_init and llama_batch_free.

If the binary is linked against common, you can use some helper functions:
- common_batch_add to add a new token into the batch
- common_batch_clear to remove all tokens from the batch

If your use case is a single sequence, then you can adapt to the new call signature of llama_batch_get_one (although this is not recommended).

The position of tokens will be tracked automatically by llama_decode. For example, if you first call llama_decode on a batch of 10 tokens, then the next llama_decode call will start decoding from position 10 (positions are zero-based, so this is the 11th token).
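The automatic position tracking can be pictured with a tiny standalone model. This is only a sketch: `SketchContext` and `sketch_decode` are hypothetical names, positions are assumed to be zero-based, and the code does not call llama.cpp or reflect how llama_decode stores its state.

```cpp
#include <vector>

// Hypothetical context that remembers the next free position of one sequence.
struct SketchContext {
    int next_pos = 0;   // assumption: an empty context starts at position 0
};

// Assign positions to the n_tokens of one call and advance the counter,
// mimicking a decode call on a batch whose positions were left unset.
inline std::vector<int> sketch_decode(SketchContext & ctx, int n_tokens) {
    std::vector<int> pos;
    for (int i = 0; i < n_tokens; ++i) {
        pos.push_back(ctx.next_pos + i);
    }
    ctx.next_pos += n_tokens;
    return pos;
}
```

In this model a first call with 10 tokens is assigned positions 0 through 9, and a following call with a single token is assigned position 10.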